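%%% Seismic interpolation using the CNN-POCS method: POCS iterations that
%%% alternate re-inserting the observed samples with denoising by a
%%% pretrained CNN.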
clear; close all
addpath('utilities');
addpath('seismicData');
addpath('seismicData/masks');

%%% Choose the original complete data. The .mat file is expected to store
%%% the complete data in a variable named 'D'.
dataChoice = 1;
switch dataChoice
    case 1
        Data = 'hyperbolic-events';
    otherwise
        error('Unexpected choice.');
end
% Load into a struct instead of "poofing" variables into the workspace, so
% the data file cannot silently overwrite script variables.
dataStruct = load([Data, '.mat']);
if ~isfield(dataStruct, 'D')
    error('Data file %s.mat does not contain the variable ''D''.', Data);
end
D = dataStruct.D;
Dataname = Data;


%%% ------------------- Parameter settings --------------------------------
% Noise level added to the data, valid range [0, 255]. A non-zero value
% performs simultaneous denoising and interpolation.
noiseL = 0;
% Down-sampling scheme:
%   'regc'  - regular down-sampling
%   'iregc' - irregular down-sampling
%   'randc' - random down-sampling
sampleType = 'iregc';
% Sampling ratio, valid in [0, 1]. The smaller it is, the fewer traces are
% kept by the down-sampling.
Ratio = 0.5;
% 1: load a prepared sampling mask (generated by makeMask.m);
% 0: generate the sampling mask on the fly.
useMaskFile = 1;
% Pre-interpolation method: one of '', 'shepard', 'nearest', 'linear',
% 'cubic', 'natural'. The empty string '' skips pre-interpolation.
interpMethod = 'shepard';
% Number of POCS iterations.
totalIter = 30;

%%% The two key parameters of the CNN-POCS method: sigma decays
%%% geometrically from lambda1 to lambda2 over the iterations.
lambda1 = 30;   % upper bound of sigma
lambda2 = 10;   % lower bound of sigma; lambda2 = 2 usually gives good results
                % on noise-free data, but it needs fine-tuning for noisy data

%%% Settings for visualizing and saving the results.
dx = 10;            % spatial sampling interval (m), used for the distance axis
dt = 0.004;         % time sampling interval (s)
freqThresh = 50;    % highest frequency (Hz) shown in the f-k plots
showResult = 1;
saveResult = 0;
saveSnr = 0;
useGPU = 0;
%%% -----------------------------------------------------------------------

%%% First of all, cast the original data into the value range [0, 1].
label = single(D);
[m, n] = size(label);
xmin = min(label(:));
label = label - xmin;
xmax = max(label(:));
% Guard against constant data: dividing by a zero range would turn every
% sample into NaN. With xmax = 1 the later de-normalization
% (x*xmax + xmin) still maps back to the original values.
if xmax == 0
    xmax = single(1);
end
label = label / xmax;
% Optional additive Gaussian noise; noiseL is given on a [0, 255] scale.
nlabel = label + single(noiseL/255*randn(size(label)));

%%% Before interpolating, down-sample the original data to obtain the
%%% observed (incomplete) data.
if useMaskFile
    maskFile = ['mask', num2str(m), 'x', num2str(n), sampleType, num2str(fix(Ratio*100)), '.mat'];
    % Load into a struct so the mask file cannot overwrite script variables.
    maskStruct = load(maskFile);
    mask = maskStruct.mask;
else
    mask = projMask(D, Ratio, sampleType);
end
input = nlabel.*mask;
% Fill the missing samples with the mean of the OBSERVED samples only:
% the values of the missing traces must not leak into the interpolation
% input. If nothing is observed, the missing samples stay at zero.
observed = (mask ~= 0);
if any(observed(:))
    input(~observed) = mean(nlabel(observed));
end

SNRinput = CalSNR(D, input*xmax+xmin);
PSNRinput = Psnr(D, input*xmax+xmin);
disp(['Input SNR: ', num2str(SNRinput), ' PSNR: ', num2str(PSNRinput)]);

%%% Pre-interpolate the down-sampled data if requested.
switch interpMethod
    case 'shepard'
        shepardWindow = 10;   % default 10, from [5, 30]
        initD = shepard_initialize(input, mask, shepardWindow);
    case ''
        % No pre-interpolation: start from the down-sampled input as is.
        initD = input;
    otherwise
        % 'nearest', 'linear', 'cubic', 'natural' are handled by initInterp.
        initD = initInterp(input, mask, interpMethod);
end

SNRinitD = CalSNR(D, initD*xmax+xmin);
PSNRinitD = Psnr(D, initD*xmax+xmin);
disp(['Pre-iterpolated SNR: ', num2str(SNRinitD), ' PSNR: ', num2str(PSNRinitD)]);

%%% Now interpolate the down-sampled data 'input' with the CNN-POCS method.
inIter = 1;   % number of CNN denoising passes per POCS iteration
% Sigma decays geometrically from lambda1 (first iteration) to lambda2
% (last iteration). max(..., 1) avoids 0/0 = NaN when totalIter == 1; the
% single iteration then uses sigma = lambda1.
SigmaS = (lambda2/lambda1).^((0:totalIter-1)/max(totalIter-1, 1))*lambda1;
% Bucket each sigma (width 2, clamped to [1, 25]); the denoiser is
% reloaded only when the bucket changes between iterations.
ns = min(25, max(ceil(SigmaS/2), 1));
% Prepend a sentinel that differs from ns(1) so the denoiser is always
% loaded on the first iteration.
ns = [ns(1)-1, ns];
snrs = zeros(1, totalIter);
folderModel = 'models';
% Load into a struct so the model file cannot overwrite script variables.
modelStruct = load(fullfile(folderModel, 'model.mat'));
CNNdenoiser = modelStruct.CNNdenoiser;

output = single(initD);
input = single(input);
mask = single(mask);

if useGPU
    input = gpuArray(input);
    mask = gpuArray(mask);
    output = gpuArray(output);
end

tic;
cpuStart = cputime;
for it = 1:totalIter
    % Data-consistency projection: observed samples come from 'input',
    % all other samples keep the current estimate.
    output = mask.*input + (1 - mask).*output;
    % Reload the denoiser only when the sigma bucket changes.
    if ns(it+1) ~= ns(it)
        net = vl_simplenn_tidy(loadmodel(SigmaS(it), CNNdenoiser));
        if useGPU
            net = vl_simplenn_move(net, 'gpu');
        end
    end
    % Denoising step: the network's last-layer output is subtracted from
    % the current estimate.
    for pass = 1:inIter
        res = vl_simplenn(net, output, [], [], 'conserveMemory', true, 'mode', 'test');
        output = output - res(end).x;
    end
    snrs(it) = CalSNR(D, output*xmax + xmin);
end

% Bring the results back to host memory when running on the GPU.
if useGPU
    input = gather(input);
    mask = gather(mask);
    output = gather(output);
end
toc;
disp(['CPU time: ', num2str(cputime - cpuStart)]);

% Map the reconstruction back to the original amplitude range and score it.
pocscnnRecon = xmin + output*xmax;
SNRCur = CalSNR(D, pocscnnRecon);
PSNRCur = Psnr(D, pocscnnRecon);
scoreMsg = ['CNN-POCS SNR: ', num2str(SNRCur), ' PSNR: ', num2str(PSNRCur)];
disp(scoreMsg);


%%% Finally, visualize the results.
if showResult
    % The data are m-by-n; imagesc maps columns to the x-axis and rows to the
    % y-axis, so the distance axis needs n samples and the time axis m.
    x = (0:n-1)*dx;
    t = (0:m-1)*dt;
    % round, not fix: fix(0.29*100) is 28 because of floating-point error.
    pctLabel = num2str(round(Ratio*100));

    % Figure 1: original, subsampled, reconstructed data and the error.
    fig1 = figure(1);
    set(fig1, 'color', 'white', 'Position', [100, 100, 900, 700]);
    colormap(gray);
    panels = {D, input, pocscnnRecon, D - pocscnnRecon};
    titles = {'Original Data', ...
              [pctLabel, '% subsampled data'], ...
              ['Reconstructed data,', ' SNR ', num2str(SNRCur, '%2.2f'), 'dB'], ...
              'Reconstruction error'};
    for p = 1:numel(panels)
        subplot(2, 2, p);
        imagesc(x, t, panels{p});
        xlabel('Distance (m)'); ylabel('Time (s)');
        title(titles{p});
    end
    drawnow;

    % Figure 2: SNR after each POCS iteration.
    figure(2); plot(snrs);

    % Figure 3: f-k spectra (log amplitude) up to freqThresh Hz.
    fig3 = figure(3);
    set(fig3, 'color', 'white', 'Position', [100, 100, 900, 700]);
    colormap(jet);
    spectra = {D, input, pocscnnRecon};
    for p = 1:numel(spectra)
        subplot(1, 3, p);
        [wn, k, f] = waveNumFreq(spectra{p}, dx, dt);
        index = find(f >= 0 & f <= freqThresh);
        imagesc(k/max(abs(k))/2, f(index), log10(1 + abs(wn(index, :))));
        xlabel('Normalized Wavenumber'); ylabel('Frequency (Hz)');
        set(gca, 'xlim', [-0.5, 0.5], 'xtick', -0.5:0.5:0.5);
    end
end

%%% Optionally save the reconstruction and the per-iteration SNR curve.
%%% Output folders are created first, because save() fails when the target
%%% folder does not exist.
ratioTag = num2str(fix(Ratio*100));   % same ratio tag as the mask file names
if saveResult
    resultDir = fullfile('seismicResult', 'pocscnn', 'results');
    if ~exist(resultDir, 'dir')
        mkdir(resultDir);
    end
    save(fullfile(resultDir, [Dataname, '-pocscnn-', sampleType, ratioTag, '.mat']), ...
        'pocscnnRecon');
end
if saveSnr
    snrDir = fullfile('seismicResult', 'pocscnn', 'snrs');
    if ~exist(snrDir, 'dir')
        mkdir(snrDir);
    end
    save(fullfile(snrDir, [Dataname, '-', sampleType, ratioTag, '-lambda-', ...
        num2str(lambda1), '-', num2str(lambda2), '.mat']), 'snrs');
end